# ---------------------------------------------------------------------
# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
# ---------------------------------------------------------------------

from __future__ import annotations

import torch

from qai_hub_models.models._shared.deeplab.model import NUM_CLASSES, DeepLabV3Model
from qai_hub_models.utils.asset_loaders import (
    CachedWebModelAsset,
    SourceAsRoot,
    find_replace_in_repo,
)

MODEL_ID = __name__.split(".")[-2]
MODEL_ASSET_VERSION = 3
# Weights downloaded from https://github.com/quic/aimet-model-zoo/releases/download/phase_2_january_artifacts/deeplab-mobilenet.pth.tar
DEEPLABV3_WEIGHTS = "deeplab-mobilenet.pth.tar"
MOBILENET_WEIGHTS = "mobilenet_v2-6a65762b.pth"
DEEPLABV3_SOURCE_REPOSITORY = "https://github.com/jfzhang95/pytorch-deeplab-xception"
DEEPLABV3_SOURCE_REPO_COMMIT = "9135e104a7a51ea9effa9c6676a2fcffe6a6a2e6"
BACKBONE = "mobilenet"


class DeepLabV3PlusMobilenet(DeepLabV3Model):
    """Exportable DeepLabV3_Plus_MobileNet image segmentation applications, end-to-end."""

    @classmethod
    def from_pretrained(cls, normalize_input: bool = True) -> DeepLabV3PlusMobilenet:
        model = _load_deeplabv3_source_model()
        dst = CachedWebModelAsset.from_asset_store(
            MODEL_ID, MODEL_ASSET_VERSION, DEEPLABV3_WEIGHTS
        ).fetch()
        checkpoint = torch.load(
            dst, map_location=torch.device("cpu"), weights_only=False
        )
        model.load_state_dict(checkpoint["state_dict"])

        return cls(model, normalize_input)


def _load_deeplabv3_source_model() -> torch.nn.Module:
    # Load DeepLabV3 model from the source repository using the given weights.
    # Returns <source repository>.modeling.deeplab.DeepLab
    with SourceAsRoot(
        DEEPLABV3_SOURCE_REPOSITORY,
        DEEPLABV3_SOURCE_REPO_COMMIT,
        MODEL_ID,
        MODEL_ASSET_VERSION,
        keep_sys_modules=True,
    ) as repo_path:
        dst = CachedWebModelAsset.from_asset_store(
            MODEL_ID, MODEL_ASSET_VERSION, DEEPLABV3_WEIGHTS
        ).fetch()
        find_replace_in_repo(
            repo_path,
            ["modeling/backbone/mobilenet.py"],
            "model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')",
            f"torch.load('{dst!s}', map_location='cpu', weights_only=False)",
        )
        # necessary import. `modeling.deeplab` comes from the DeepLabV3 repo.
        from modeling.deeplab import DeepLab

        return DeepLab(backbone=BACKBONE, sync_bn=False, num_classes=NUM_CLASSES)
